#!/usr/bin/python3
# ******************************************************************************
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
# licensed under the Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#     http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN 'AS IS' BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
# PURPOSE.
# See the Mulan PSL v2 for more details.
# ******************************************************************************/
import re
from collections import defaultdict
from typing import List, Optional, Tuple

from ceres.conf.constant import CommandExitCode, TaskExecuteRes
from ceres.function.log import LOGGER
from ceres.function.util import execute_shell_command


class HotpatchRemoveManage:
    """Remove the hotpatch packages that were installed to fix given CVEs."""

    # One row of `dnf hot-updateinfo list cves --installed`: captures the CVE id, two
    # intermediate columns (unused here) and the hotpatch package, i.e. the token starting with "patch".
    _CVE_HOTPATCH_PATTERN = re.compile(r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+|-)\s+(patch\S+)")

    def remove_hotpatch(self, cves: List[str]) -> dict:
        """
        remove hotpatch

        Args:
            cves(list): List of CVE IDs fixed by hotpatch,e.g.
                ["CVE-XXXX-XXXX"]

        Returns:
            dict e.g
                {
                    "status": "fail/succeed",
                    "cves": [{
                        "cve_id": cve,
                        "result": "succeed",
                        "log": "rollback succeed"
                    }]
                }
            "status" is FAIL only when the installed hotpatch list cannot be queried;
            otherwise it is SUCCEED and each CVE's own outcome is reported in "cves".
        """
        hotpatch_list = self._hotpatch_list_cve()
        if not hotpatch_list:
            return {
                "status": TaskExecuteRes.FAIL,
                "cves": [
                    dict(cve_id=cve, log="No valid hotpatch is matched.", result=TaskExecuteRes.FAIL) for cve in cves
                ],
            }

        # Several CVEs may share one hotpatch: remove each package only once.
        wait_to_remove_patch = set()
        for cve in cves:
            wait_to_remove_patch.update(hotpatch_list.get(cve, set()))

        hotpatch_remove_res = {}
        for patch in sorted(wait_to_remove_patch):
            remove_result, log = self._hotpatch_remove(patch)
            hotpatch_remove_res[patch] = {
                "result": TaskExecuteRes.SUCCEED if remove_result else TaskExecuteRes.FAIL,
                "log": log,
            }

        cve_hotpatch_remove_result = []
        for cve in cves:
            if cve not in hotpatch_list:
                cve_hotpatch_remove_result.append(
                    {
                        "cve_id": cve,
                        "log": "No valid hot patch is matched.",
                        "result": TaskExecuteRes.FAIL,
                    }
                )
                continue

            patches = sorted(hotpatch_list[cve])
            # A CVE counts as rolled back only if every hotpatch mapped to it was removed.
            all_removed = all(hotpatch_remove_res[patch]["result"] == TaskExecuteRes.SUCCEED for patch in patches)
            cve_hotpatch_remove_result.append(
                {
                    "cve_id": cve,
                    "log": "\n".join(hotpatch_remove_res[patch]["log"] for patch in patches),
                    "result": TaskExecuteRes.SUCCEED if all_removed else TaskExecuteRes.FAIL,
                }
            )

        return {"status": TaskExecuteRes.SUCCEED, "cves": cve_hotpatch_remove_result}

    @staticmethod
    def _hotpatch_list_cve() -> Optional[dict]:
        """
        Run the dnf hotpatch list cve command to query the hotpatch list corresponding to the cve

        Returns:
            dict: e.g
                {
                    "CVE-XXXX-XXX": {"patch 1", "patch 2"}
                }
            None if the command fails or its output contains no CVE/hotpatch rows.
        """
        code, stdout, _ = execute_shell_command(["dnf hot-updateinfo list cves --installed", "grep patch"])
        if code != CommandExitCode.SUCCEED:
            LOGGER.error("Failed to query the hotpatch list.")
            return None

        hotpatch_list = HotpatchRemoveManage._parse_cve_hotpatch_output(stdout)
        if not hotpatch_list:
            LOGGER.error("Failed to query the hotpatch list.")
            return None

        return hotpatch_list

    @staticmethod
    def _parse_cve_hotpatch_output(stdout: str) -> dict:
        """
        Parse the output of `dnf hot-updateinfo list cves --installed` into a CVE -> hotpatches mapping.

        For hotpatches whose name (without the last two '-'-separated fields) ends with "ACC",
        every occurrence is replaced by the highest version seen for that name.

        Args:
            stdout: command output

        Returns:
            defaultdict(set): CVE id -> set of hotpatch package names; empty when nothing matches.
        """
        all_cve_info = HotpatchRemoveManage._CVE_HOTPATCH_PATTERN.findall(stdout)

        # Highest version per ACC hotpatch name, compared numerically so "ACC-10" beats "ACC-2".
        latest_acc_hotpatch = {}
        for _, _, _, hotpatch in all_cve_info:
            name = hotpatch.rsplit("-", 2)[0]
            if not name.endswith("ACC"):
                continue
            current = latest_acc_hotpatch.get(name)
            if current is None or HotpatchRemoveManage._version_sort_key(
                hotpatch
            ) > HotpatchRemoveManage._version_sort_key(current):
                latest_acc_hotpatch[name] = hotpatch

        # Keep every hotpatch listed for a CVE (a CVE may be fixed in more than one package).
        hotpatch_list = defaultdict(set)
        for cve_id, _, _, hotpatch in all_cve_info:
            hotpatch_list[cve_id].add(latest_acc_hotpatch.get(hotpatch.rsplit("-", 2)[0], hotpatch))

        return hotpatch_list

    @staticmethod
    def _version_sort_key(hotpatch: str) -> list:
        """Split a package name into text and integer runs so numeric fields compare numerically."""
        # re.split with a capturing group alternates text (even index) and digit runs (odd index),
        # so both keys hold the same type at each position and compare safely.
        return [int(part) if index % 2 else part for index, part in enumerate(re.split(r"(\d+)", hotpatch))]

    def _hotpatch_remove(self, hotpatch: str) -> Tuple[bool, str]:
        """
        remove hotpatch package

        Args:
            hotpatch: hotpatch package which needs to remove

        Returns:
            Tuple[bool, str]: whether `dnf remove` exited successfully, and the command log
        """
        cmd = [f"dnf remove {hotpatch} -y"]
        code, stdout, stderr = execute_shell_command(cmd)
        return code == CommandExitCode.SUCCEED, f"Command:{cmd}\n\n{stdout}\n{stderr}\n"
